import cv2
import numpy as np
import onnxruntime

class YOLOv7:
    """Run a YOLOv7 ONNX model with ONNX Runtime and draw its detections.

    An instance is callable: ``detector(image, conf, nms)`` is the same as
    ``detector.detect_objects(image, conf, nms)``. The results of the most
    recent detection are kept on ``self.boxes``, ``self.scores`` and
    ``self.class_ids`` so that ``draw_detections`` can render them.
    """

    def __init__(
            self,
            model_path,
            labels_path,
            engine_path,
            official_nms=False
        ):
        """Load the class names and create the inference session.

        Args:
            model_path: Path of the ONNX model passed to ONNX Runtime.
            labels_path: Text file with one class name per line; the line
                index is used as the class id.
            engine_path: Forwarded to ``initialize_model``. Only referenced
                by the (currently commented-out) TensorRT provider options.
            official_nms: Stored on the instance and used when computing
                ``self.has_postprocess``.
        """
        self.official_nms = official_nms

        # The context manager closes the file; no explicit close() is needed.
        with open(labels_path, 'r') as f:
            self.class_names = [cname.strip() for cname in f]

        # One colour per class: a float array of shape (num_classes, 3) with
        # values drawn uniformly from [0, 255). The fixed seed keeps the
        # colours stable between runs.
        rng = np.random.default_rng(3)
        self.colors = rng.uniform(0, 255, size=(len(self.class_names), 3))

        # Results of the most recent detection. Initialised empty so that
        # draw_detections is a no-op instead of an AttributeError when it is
        # called before any detection has been run.
        self.boxes, self.scores, self.class_ids = [], [], []

        # Initialize model
        self.initialize_model(model_path, engine_path)

    def __call__(self, image, confidence_threshold, nms_threshold):
        """Shorthand for ``detect_objects``."""
        return self.detect_objects(image, confidence_threshold, nms_threshold)

    def xywh2xyxy(self, x):
        """Convert centre-format boxes to corner-format boxes.

        Args:
            x: Array-like whose last axis holds (cx, cy, w, h).

        Returns:
            A new array of the same shape and dtype whose last axis holds
            (x1, y1, x2, y2). The input is not modified.
        """
        src = np.asarray(x)
        half_w = src[..., 2] / 2
        half_h = src[..., 3] / 2

        y = np.copy(src)
        y[..., 0] = src[..., 0] - half_w
        y[..., 1] = src[..., 1] - half_h
        y[..., 2] = src[..., 0] + half_w
        y[..., 3] = src[..., 1] + half_h
        return y

    def initialize_model(self, model_path, engine_path):
        """Create the ONNX Runtime session and cache its input/output info.

        Only the CPU execution provider is enabled. The GPU provider
        configurations are kept below as comments; ``engine_path`` is only
        referenced by the TensorRT one, so it is unused while they stay
        disabled.
        """
        self.session = onnxruntime.InferenceSession(
            model_path,
            providers=[
                # (
                #     'TensorrtExecutionProvider',
                #     {
                #         'device_id': 0,
                #         'trt_max_workspace_size': 2147483648,
                #         'trt_fp16_enable': True,
                #         'trt_engine_cache_enable': True,
                #         'trt_engine_cache_path': '{}'.format(engine_path),
                #     }
                # ),
                # (
                #     'CUDAExecutionProvider',
                #     {
                #         'device_id': 0,
                #         'arena_extend_strategy': 'kNextPowerOfTwo',
                #         'gpu_mem_limit': 2 * 1024 * 1024 * 1024,
                #         'cudnn_conv_algo_search': 'EXHAUSTIVE',
                #         'do_copy_in_default_stream': True,
                #     }
                # )
                'CPUExecutionProvider'
            ]
        )
        # Get model info
        self.get_input_details()
        self.get_output_details()

        self.has_postprocess = 'score' in self.output_names or self.official_nms

    def detect_objects(self, image, confidence_threshold, nms_threshold):
        """Run the full pipeline on one image.

        Args:
            image: Image array; see ``prepare_input`` for the expected layout.
            confidence_threshold: Minimum score for a detection to be kept.
            nms_threshold: Overlap threshold passed to non-maximum suppression.

        Returns:
            ``(boxes, scores, class_names)`` as produced by
            ``process_output``. The same values are stored on the instance.
        """
        input_tensor = self.prepare_input(image)

        # Perform inference on the image
        outputs = self.inference(input_tensor)

        # Process output data
        self.boxes, self.scores, self.class_ids = self.process_output(
            outputs, confidence_threshold, nms_threshold)

        return self.boxes, self.scores, self.class_ids

    def prepare_input(self, image):
        """Turn an image into the NCHW float32 tensor fed to the model.

        The image is converted with ``cv2.COLOR_BGR2RGB`` (so it is expected
        in OpenCV's BGR channel order), resized to the model input size
        without preserving aspect ratio, scaled to [0, 1] and given a leading
        batch axis. The original height/width are remembered on the instance
        for ``rescale_boxes``.
        """
        self.img_height, self.img_width = image.shape[:2]

        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(rgb, (self.input_width, self.input_height))

        # HWC uint8 -> CHW in [0, 1], then add the batch axis.
        chw = (resized / 255.0).transpose(2, 0, 1)
        return chw[np.newaxis, :, :, :].astype(np.float32)

    def rescale_boxes(self, boxes):
        """Map boxes from model-input pixels to original-image pixels.

        Args:
            boxes: Array-like whose last axis has four values ordered
                (x, y, w, h) or (x1, y1, x2, y2); x-like values are scaled by
                the width ratio and y-like values by the height ratio.

        Returns:
            A float32 array with the same shape as ``boxes``.
        """
        input_dims = np.array(
            [self.input_width, self.input_height, self.input_width, self.input_height])
        image_dims = np.array(
            [self.img_width, self.img_height, self.img_width, self.img_height],
            dtype=np.float32)

        normalized = np.divide(boxes, input_dims, dtype=np.float32)
        return normalized * image_dims

    def process_output(self, output, conf_threshold, nms_threshold):
        """Filter, rescale and de-duplicate the raw model outputs.

        Args:
            output: Three model outputs ``(boxes, scores, class_ids)``, each
                with a leading batch axis; only batch element 0 is used.
                Boxes are treated as centre-format (cx, cy, w, h) in
                model-input pixels, with one score and one class id per box.
            conf_threshold: Detections with a score not above this are dropped.
            nms_threshold: Overlap threshold for ``cv2.dnn.NMSBoxes``.

        Returns:
            ``(boxes, scores, class_names)``: a float32 array of corner-format
            (x1, y1, x2, y2) boxes in original-image pixels, a list of scores
            and a list of class-name strings. When nothing survives, three
            empty lists are returned.
        """
        boxes, scores, class_ids = output
        boxes = np.asarray(boxes[0], dtype=np.float32).reshape(-1, 4)
        # Flatten so that both (N,) and (N, 1) shaped outputs are handled
        # without relying on truthiness / int() of one-element arrays.
        scores = np.asarray(scores[0], dtype=np.float32).reshape(-1)
        class_ids = np.asarray(class_ids[0]).reshape(-1).astype(int)

        confident = scores > conf_threshold
        if not confident.any():
            return [], [], []

        # Scale boxes to original image dimensions
        boxes = self.rescale_boxes(boxes[confident])
        scores = scores[confident]
        class_ids = class_ids[confident]

        # cv2.dnn.NMSBoxes expects (left, top, width, height) rectangles,
        # while the boxes here are centre-based, so shift the origin for the
        # overlap computation only.
        nms_rects = boxes.copy()
        nms_rects[:, :2] -= nms_rects[:, 2:] / 2
        selected = cv2.dnn.NMSBoxes(
            nms_rects.tolist(), scores.tolist(), conf_threshold, nms_threshold)
        # Depending on the OpenCV version the indices come back as an empty
        # tuple, a flat array or an (N, 1) array; normalise to a flat array.
        selected = np.asarray(selected, dtype=int).reshape(-1)
        if selected.size == 0:
            return [], [], []

        # Convert boxes to xyxy format
        fin_boxes = self.xywh2xyxy(boxes[selected])
        fin_scores = list(scores[selected])

        # Convert class ids to class names
        fin_class_ids = [self.class_names[i] for i in class_ids[selected]]
        return fin_boxes, fin_scores, fin_class_ids

    def _class_index(self, class_id):
        """Return the integer class index for a class name or a numeric id.

        ``process_output`` stores class names, so names are looked up in
        ``self.class_names`` (``ValueError`` if absent); anything else is
        converted with ``int``.
        """
        if isinstance(class_id, str):
            return self.class_names.index(class_id)
        return int(class_id)

    def draw_detections(self, image, draw_scores=True, mask_alpha=0.4):
        """Draw the most recent detections on a copy of ``image``.

        Args:
            image: Image to annotate; it is not modified.
            draw_scores: When true the caption is ``"<label> <score>%"``,
                otherwise only the label is drawn.
            mask_alpha: Weight of the filled-rectangle layer in the blend.

        Returns:
            The blended image ``mask_alpha * filled + (1 - mask_alpha) * outlined``.
        """
        mask_img = image.copy()
        det_img = image.copy()

        img_height, img_width = image.shape[:2]
        shortest_side = min(img_height, img_width)
        size = shortest_side * 0.0006
        text_thickness = int(shortest_side * 0.001)

        # Draw bounding boxes and labels of detections
        for box, score, class_id in zip(self.boxes, self.scores, self.class_ids):
            # self.class_ids holds class names after process_output, so map
            # back to an index before touching the colour table.
            class_index = self._class_index(class_id)
            color = tuple(self.colors[class_index].tolist())

            x1, y1, x2, y2 = np.asarray(box).astype(int).tolist()

            # Draw rectangle
            cv2.rectangle(det_img, (x1, y1), (x2, y2), color, 2)

            # Draw fill rectangle in mask image
            cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)

            label = self.class_names[class_index]
            caption = f'{label} {int(score * 100)}%' if draw_scores else label
            (tw, th), _ = cv2.getTextSize(text=caption, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                          fontScale=size, thickness=text_thickness)
            th = int(th * 1.2)

            # Caption background and text go on both layers so the label
            # stays visible after blending.
            for layer in (det_img, mask_img):
                cv2.rectangle(layer, (x1, y1), (x1 + tw, y1 - th), color, -1)
            for layer in (det_img, mask_img):
                cv2.putText(layer, caption, (x1, y1),
                            cv2.FONT_HERSHEY_SIMPLEX, size, (255, 255, 255),
                            text_thickness, cv2.LINE_AA)

        return cv2.addWeighted(mask_img, mask_alpha, det_img, 1 - mask_alpha, 0)

    def get_input_details(self):
        """Cache the session's input names and the first input's shape.

        Height and width are read from axes 2 and 3 of the first input's
        shape.
        """
        model_inputs = self.session.get_inputs()
        self.input_names = [model_input.name for model_input in model_inputs]

        self.input_shape = model_inputs[0].shape
        self.input_height = self.input_shape[2]
        self.input_width = self.input_shape[3]

    def get_output_details(self):
        """Cache the names of all session outputs."""
        model_outputs = self.session.get_outputs()
        self.output_names = [model_output.name for model_output in model_outputs]

    def inference(self, input_tensor):
        """Run the session on ``input_tensor`` and return all outputs."""
        outputs = self.session.run(self.output_names, {self.input_names[0]: input_tensor})
        return outputs
